{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Link Prediction using Graph Neural Networks\n",
    "\n",
    "GNNs are powerful tools for many machine learning tasks on graphs. This tutorial teaches the basic workflow of using GNNs for link prediction. We again use the Zachery's Karate Club graph but try to predict interactions between two members.\n",
    "\n",
    "In this tutorial, you will learn:\n",
    "* Prepare training and testing sets for link prediction task.\n",
    "* Build a GNN-based link prediction model.\n",
    "* Train the model and verify the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tutorial_utils import setup_tf\n",
    "setup_tf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl\n",
    "import tensorflow as tf\n",
    "import itertools\n",
    "import numpy as np\n",
    "import scipy.sparse as sp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load graph and features\n",
    "\n",
    "Following the last [session](./2_gnn.ipynb), we first load the Zachery's Karate Club graph and creates node embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tutorial_utils import load_zachery\n",
    "\n",
    "# ----------- 0. load graph -------------- #\n",
    "g = load_zachery()\n",
    "print(g)\n",
    "\n",
    "# ----------- 1. node features -------------- #\n",
    "node_embed = tf.keras.layers.Embedding(g.number_of_nodes(), 5,\n",
    "                                       embeddings_initializer='glorot_uniform')  # Every node has an embedding of size 5.\n",
    "node_embed(1) # intialize embedding layer\n",
    "inputs = node_embed.embeddings # # Use the embedding weight as the node features.\n",
    "print(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare training and testing sets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In general, a link prediction data set contains two types of edges, *positive* and *negative edges*. Positive edges are usually drawn from the existing edges in the graph. In this example, we randomly pick 50 edges for testing and leave the rest for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split edge set for training and testing\n",
    "u, v = g.edges()\n",
    "u, v = u.numpy(), v.numpy()\n",
    "eids = np.arange(g.number_of_edges())\n",
    "eids = np.random.permutation(eids)\n",
    "test_pos_u, test_pos_v = u[eids[:50]], v[eids[:50]]\n",
    "train_pos_u, train_pos_v = u[eids[50:]], v[eids[50:]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the number of negative edges is large, sampling is usually desired. How to choose proper negative sampling algorithms is a widely-studied topic and is out of scope of this tutorial. Since our example graph is quite small (with only 34 nodes), we enumerate all the missing edges and randomly pick 50 for testing and 150 for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find all negative edges and split them for training and testing\n",
    "adj = sp.coo_matrix((np.ones(len(u)), (u, v)))\n",
    "adj_neg = 1 - adj.todense() - np.eye(34)\n",
    "neg_u, neg_v = np.where(adj_neg != 0)\n",
    "neg_eids = np.random.choice(len(neg_u), 200)\n",
    "test_neg_u, test_neg_v = neg_u[neg_eids[:50]], neg_v[neg_eids[:50]]\n",
    "train_neg_u, train_neg_v = neg_u[neg_eids[50:]], neg_v[neg_eids[50:]]"
   ]
  },
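  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the enumeration concrete, here is a minimal sketch on a hypothetical 4-node path graph. The `toy_*` names are illustrative only and are not part of the tutorial's data. It uses the same dense-adjacency trick: an entry of `1 - A - I` is nonzero exactly where a pair of distinct nodes has no edge, assuming the graph has no self-loops and no duplicate edges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "\n",
    "# Hypothetical toy graph: the path 0-1-2-3, stored with both edge directions.\n",
    "toy_u = np.array([0, 1, 1, 2, 2, 3])\n",
    "toy_v = np.array([1, 0, 2, 1, 3, 2])\n",
    "toy_n = 4\n",
    "\n",
    "toy_adj = sp.coo_matrix((np.ones(len(toy_u)), (toy_u, toy_v)), shape=(toy_n, toy_n))\n",
    "# Edges give 1 - 1 - 0 = 0, the diagonal gives 1 - 0 - 1 = 0, missing pairs give 1.\n",
    "toy_adj_neg = 1 - toy_adj.toarray() - np.eye(toy_n)\n",
    "toy_neg_u, toy_neg_v = np.where(toy_adj_neg != 0)\n",
    "toy_neg_pairs = list(zip(toy_neg_u.tolist(), toy_neg_v.tolist()))\n",
    "print(toy_neg_pairs)"
   ]
  },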
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Put positive and negative edges together and form training and testing sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create training set.\n",
    "train_u = tf.concat([train_pos_u, train_neg_u], axis=0)\n",
    "train_v = tf.concat([train_pos_v, train_neg_v], axis=0)\n",
    "train_label = tf.concat([tf.zeros(len(train_pos_u)), tf.ones(len(train_neg_u))], axis=0)\n",
    "\n",
    "# Create testing set.\n",
    "test_u = tf.concat([test_pos_u, test_neg_u], axis=0)\n",
    "test_v = tf.concat([test_pos_v, test_neg_v], axis=0)\n",
    "test_label = tf.concat([tf.zeros(len(test_pos_u)), tf.ones(len(test_neg_u))], axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a GraphSAGE model\n",
    "\n",
    "Our model consists of two layers, each computes new node representations by aggregating neighbor information. The equations are:\n",
    "\n",
    "$$\n",
    "h_{\\mathcal{N}(v)}^k\\leftarrow \\text{AGGREGATE}_k\\{h_u^{k-1},\\forall u\\in\\mathcal{N}(v)\\}\n",
    "$$\n",
    "\n",
    "$$\n",
    "h_v^k\\leftarrow \\text{ReLU}\\left(W^k\\cdot \\text{CONCAT}(h_v^{k-1}, h_{\\mathcal{N}(v)}^k) \\right)\n",
    "$$\n",
    "\n",
    "DGL provides implementation of many popular neighbor aggregation modules. They all can be invoked easily with one line of codes. See the full list of supported [graph convolution modules](https://docs.dgl.ai/api/python/nn.pytorch.html#module-dgl.nn.pytorch.conv)."
   ]
  },
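  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two equations can be traced by hand on a tiny example. The sketch below is plain NumPy on a hypothetical 3-node graph with mean aggregation. All `toy_*` names and the random weight matrix are assumptions made for illustration, and DGL's `SAGEConv` may organize the computation differently internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy graph: node 0 is connected to nodes 1 and 2.\n",
    "toy_neighbors = {0: [1, 2], 1: [0], 2: [0]}\n",
    "# Previous-layer representations h^{k-1}, one row per node.\n",
    "toy_feat = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])\n",
    "# Illustrative weight matrix W^k: maps the concatenation (size 4) to 3 features.\n",
    "toy_W = np.random.default_rng(0).normal(size=(3, 4))\n",
    "\n",
    "# AGGREGATE: mean of the neighbors' previous representations.\n",
    "toy_agg = np.stack([toy_feat[toy_neighbors[n]].mean(axis=0) for n in range(3)])\n",
    "# CONCAT the node's own representation with the aggregate, apply W^k, then ReLU.\n",
    "toy_concat = np.concatenate([toy_feat, toy_agg], axis=1)\n",
    "toy_new = np.maximum(toy_concat @ toy_W.T, 0.0)\n",
    "\n",
    "print(toy_agg)\n",
    "print(toy_new.shape)"
   ]
  },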
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dgl.nn import SAGEConv\n",
    "\n",
    "# ----------- 2. create model -------------- #\n",
    "# build a two-layer GraphSAGE model\n",
    "class GraphSAGE(tf.keras.layers.Layer):\n",
    "    def __init__(self, in_feats, h_feats):\n",
    "        super(GraphSAGE, self).__init__()\n",
    "        self.conv1 = SAGEConv(in_feats, h_feats, 'mean')\n",
    "        self.conv2 = SAGEConv(h_feats, h_feats, 'mean')\n",
    "    \n",
    "    def call(self, g, in_feat):\n",
    "        h = self.conv1(g, in_feat)\n",
    "        h = tf.nn.relu(h)\n",
    "        h = self.conv2(g, h)\n",
    "        return h\n",
    "    \n",
    "# Create the model with given dimensions \n",
    "# input layer dimension: 5, node embeddings\n",
    "# hidden layer dimension: 16\n",
    "net = GraphSAGE(5, 16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then optimize the model using the following loss function.\n",
    "\n",
    "$$\n",
    "\\hat{y}_{u\\sim v} = \\sigma(h_u^T h_v)\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\mathcal{L} = -\\sum_{u\\sim v\\in \\mathcal{D}}\\left( y_{u\\sim v}\\log(\\hat{y}_{u\\sim v}) + (1-y_{u\\sim v})\\log(1-\\hat{y}_{u\\sim v})) \\right)\n",
    "$$\n",
    "\n",
    "Essentially, the model predicts a score for each edge by dot-producting the representations of its two end-points. It then computes a binary cross entropy loss with the target $y$ being 0 or 1 meaning whether the edge is a positive one or not."
   ]
  },
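  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of the score and the loss, the NumPy sketch below evaluates both for two node pairs. The representations and targets are hypothetical, and the sketch assumes a target of 1 for the pair that should score high. It averages over the pairs, whereas the formula above sums; the two differ only by a constant factor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical 2-dimensional representations for three nodes.\n",
    "toy_h = np.array([[1.0, 0.0], [1.0, 1.0], [-1.0, 0.5]])\n",
    "# Two candidate pairs: (0, 1) and (0, 2).\n",
    "toy_src = np.array([0, 0])\n",
    "toy_dst = np.array([1, 2])\n",
    "# Assumed targets for this sketch: 1 for the first pair, 0 for the second.\n",
    "toy_y = np.array([1.0, 0.0])\n",
    "\n",
    "# Score each pair with the dot product of its endpoint representations.\n",
    "toy_score = (toy_h[toy_src] * toy_h[toy_dst]).sum(axis=1)\n",
    "# The sigmoid turns the scores into probabilities.\n",
    "toy_prob = 1.0 / (1.0 + np.exp(-toy_score))\n",
    "# Binary cross-entropy, averaged over the pairs.\n",
    "toy_loss = -np.mean(toy_y * np.log(toy_prob) + (1 - toy_y) * np.log(1 - toy_prob))\n",
    "\n",
    "print(toy_score, toy_prob, toy_loss)"
   ]
  },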
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------- 3. set up loss and optimizer -------------- #\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)\n",
    "loss_fcn = tf.keras.losses.BinaryCrossentropy(\n",
    "    from_logits=False)\n",
    "\n",
    "# ----------- 4. training -------------------------------- #\n",
    "all_logits = []\n",
    "for e in range(100):\n",
    "    \n",
    "    with tf.GradientTape() as tape:\n",
    "        tape.watch(inputs) # optimize embedding layer also\n",
    "        # forward\n",
    "        logits = net(g, inputs)\n",
    "        pred = tf.sigmoid(tf.reduce_sum(tf.gather(logits, train_u) *\n",
    "                                        tf.gather(logits, train_v), axis=1))\n",
    "\n",
    "        # compute loss\n",
    "        loss = loss_fcn(train_label, pred)\n",
    "\n",
    "        # backward\n",
    "        grads = tape.gradient(loss, net.trainable_weights + node_embed.trainable_weights)        \n",
    "        optimizer.apply_gradients(zip(grads, net.trainable_weights + node_embed.trainable_weights))\n",
    "        all_logits.append(logits.numpy())\n",
    "\n",
    "    if e % 5 == 0:\n",
    "        print('In epoch {}, loss: {}'.format(e, loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------- 5. check results ------------------------ #\n",
    "pred = tf.sigmoid(tf.reduce_sum(tf.gather(logits, test_u) *\n",
    "                                tf.gather(logits, test_v), axis=1)).numpy()\n",
    "print('Accuracy', ((pred >= 0.5) == test_label.numpy()).sum().item() / len(pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "3.6.8",
   "language": "python",
   "name": "3.6.8"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
